// BlackoutClient / Assets / Best HTTP / Source / SecureProtocol / crypto / tls / TlsPskKeyExchange.cs
#if !BESTHTTP_DISABLE_ALTERNATE_SSL && (!UNITY_WEBGL || UNITY_EDITOR)
#pragma warning disable
using System;
using System.Collections;
using System.IO;

using BestHTTP.SecureProtocol.Org.BouncyCastle.Asn1.X509;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Crypto.Parameters;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Security;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities;
using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.IO;

namespace BestHTTP.SecureProtocol.Org.BouncyCastle.Crypto.Tls
{
    /// <summary>(D)TLS PSK key exchange (RFC 4279).</summary>
    /// <remarks>
    /// Supports the PSK, DHE_PSK, ECDHE_PSK and RSA_PSK key exchange algorithms. The same class is used for
    /// both roles: the <c>Generate*</c>/<c>Process*</c> pairs below are the two ends of each handshake message.
    /// Missing handshake state is reported as a <see cref="TlsFatalAlert"/> rather than surfacing as a
    /// <see cref="NullReferenceException"/> or <see cref="InvalidCastException"/>.
    /// </remarks>
    public class TlsPskKeyExchange
        :   AbstractTlsKeyExchange
    {
        // Source of the PSK identity and key; consulted in GenerateClientKeyExchange.
        protected TlsPskIdentity mPskIdentity;
        // Source of the identity hint and PSK lookup; consulted in GenerateServerKeyExchange
        // and ProcessClientKeyExchange.
        protected TlsPskIdentityManager mPskIdentityManager;

        // Passed to TlsDHUtilities.ReceiveDHParameters when DH parameters are read from the peer.
        protected TlsDHVerifier mDHVerifier;
        // DH group: supplied at construction and replaced by the received group in ProcessServerKeyExchange.
        protected DHParameters mDHParameters;
        protected int[] mNamedCurves;
        protected byte[] mClientECPointFormats, mServerECPointFormats;

        // Identity hint: generated by the identity manager or read from the ServerKeyExchange.
        protected byte[] mPskIdentityHint = null;
        // The pre-shared key; zeroed and released by GeneratePremasterSecret.
        protected byte[] mPsk = null;

        // Local ephemeral DH private key and the peer's DH public key (DHE_PSK only).
        protected DHPrivateKeyParameters mDHAgreePrivateKey = null;
        protected DHPublicKeyParameters mDHAgreePublicKey = null;

        // Local ephemeral EC private key and the peer's EC public key (ECDHE_PSK only).
        protected ECPrivateKeyParameters mECAgreePrivateKey = null;
        protected ECPublicKeyParameters mECAgreePublicKey = null;

        // Public key taken from the first server certificate, and its validated RSA view (RSA_PSK only).
        protected AsymmetricKeyParameter mServerPublicKey = null;
        protected RsaKeyParameters mRsaServerPublicKey = null;
        // Credentials used to decrypt the client's encrypted premaster secret (RSA_PSK only).
        protected TlsEncryptionCredentials mServerCredentials = null;
        // RSA-exchanged premaster secret; used as the "other secret" for RSA_PSK.
        protected byte[] mPremasterSecret;

        /// <summary>
        /// Creates a PSK key exchange that uses a <see cref="DefaultTlsDHVerifier"/> for received DH parameters.
        /// </summary>
        [Obsolete("Use constructor that takes a TlsDHVerifier")]
        public TlsPskKeyExchange(int keyExchange, IList supportedSignatureAlgorithms, TlsPskIdentity pskIdentity,
            TlsPskIdentityManager pskIdentityManager, DHParameters dhParameters, int[] namedCurves,
            byte[] clientECPointFormats, byte[] serverECPointFormats)
            :   this(keyExchange, supportedSignatureAlgorithms, pskIdentity, pskIdentityManager, new DefaultTlsDHVerifier(),
                    dhParameters, namedCurves, clientECPointFormats, serverECPointFormats)
        {
        }

        /// <summary>Creates a PSK key exchange for one of the supported PSK key exchange algorithms.</summary>
        /// <param name="keyExchange">One of DHE_PSK, ECDHE_PSK, PSK or RSA_PSK.</param>
        /// <param name="supportedSignatureAlgorithms">Forwarded unchanged to the base class.</param>
        /// <param name="pskIdentity">Identity/key provider used when generating the client key exchange.</param>
        /// <param name="pskIdentityManager">Hint/key provider used when acting on the server messages.</param>
        /// <param name="dhVerifier">Verifier applied to DH parameters received from the peer.</param>
        /// <param name="dhParameters">DH group used when generating the server key exchange for DHE_PSK.</param>
        /// <param name="namedCurves">Named curves used for ECDHE_PSK.</param>
        /// <param name="clientECPointFormats">Client EC point formats used for ECDHE_PSK.</param>
        /// <param name="serverECPointFormats">Server EC point formats used for ECDHE_PSK.</param>
        /// <exception cref="InvalidOperationException">
        /// <paramref name="keyExchange"/> is not a supported PSK key exchange algorithm.
        /// </exception>
        public TlsPskKeyExchange(int keyExchange, IList supportedSignatureAlgorithms, TlsPskIdentity pskIdentity,
            TlsPskIdentityManager pskIdentityManager, TlsDHVerifier dhVerifier, DHParameters dhParameters, int[] namedCurves,
            byte[] clientECPointFormats, byte[] serverECPointFormats)
            :   base(keyExchange, supportedSignatureAlgorithms)
        {
            switch (keyExchange)
            {
            case KeyExchangeAlgorithm.DHE_PSK:
            case KeyExchangeAlgorithm.ECDHE_PSK:
            case KeyExchangeAlgorithm.PSK:
            case KeyExchangeAlgorithm.RSA_PSK:
                break;
            default:
                throw new InvalidOperationException("unsupported key exchange algorithm");
            }

            this.mPskIdentity = pskIdentity;
            this.mPskIdentityManager = pskIdentityManager;
            this.mDHVerifier = dhVerifier;
            this.mDHParameters = dhParameters;
            this.mNamedCurves = namedCurves;
            this.mClientECPointFormats = clientECPointFormats;
            this.mServerECPointFormats = serverECPointFormats;
        }

        /// <summary>Called when no server credentials are available; only RSA_PSK forbids that.</summary>
        /// <exception cref="TlsFatalAlert">unexpected_message when the key exchange is RSA_PSK.</exception>
        public override void SkipServerCredentials()
        {
            if (mKeyExchange == KeyExchangeAlgorithm.RSA_PSK)
                throw new TlsFatalAlert(AlertDescription.unexpected_message);
        }

        /// <summary>
        /// Validates the certificate carried by the server credentials and then stores the credentials.
        /// </summary>
        /// <exception cref="TlsFatalAlert">
        /// internal_error when the credentials are not <see cref="TlsEncryptionCredentials"/>; any alert raised
        /// by <see cref="ProcessServerCertificate"/>.
        /// </exception>
        public override void ProcessServerCredentials(TlsCredentials serverCredentials)
        {
            TlsEncryptionCredentials encryptionCredentials = serverCredentials as TlsEncryptionCredentials;
            if (encryptionCredentials == null)
                throw new TlsFatalAlert(AlertDescription.internal_error);

            ProcessServerCertificate(serverCredentials.Certificate);

            // Only keep the credentials once their certificate has been accepted.
            this.mServerCredentials = encryptionCredentials;
        }

        /// <summary>Builds the ServerKeyExchange body: the identity hint followed by any ephemeral parameters.</summary>
        /// <returns>
        /// The encoded message, or null when there is no hint and the key exchange does not require
        /// a ServerKeyExchange.
        /// </returns>
        /// <exception cref="TlsFatalAlert">internal_error when DHE_PSK is selected without DH parameters.</exception>
        public override byte[] GenerateServerKeyExchange()
        {
            this.mPskIdentityHint = mPskIdentityManager.GetHint();

            if (this.mPskIdentityHint == null && !RequiresServerKeyExchange)
                return null;

            MemoryStream buf = new MemoryStream();

            // A missing hint is encoded as an empty opaque field; the field itself stays null.
            byte[] encodedHint = this.mPskIdentityHint == null ? TlsUtilities.EmptyBytes : this.mPskIdentityHint;
            TlsUtilities.WriteOpaque16(encodedHint, buf);

            if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
            {
                if (this.mDHParameters == null)
                    throw new TlsFatalAlert(AlertDescription.internal_error);

                this.mDHAgreePrivateKey = TlsDHUtilities.GenerateEphemeralServerKeyExchange(mContext.SecureRandom,
                    this.mDHParameters, buf);
            }
            else if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
            {
                this.mECAgreePrivateKey = TlsEccUtilities.GenerateEphemeralServerKeyExchange(mContext.SecureRandom,
                    mNamedCurves, mClientECPointFormats, buf);
            }

            return buf.ToArray();
        }

        /// <summary>
        /// Extracts and validates the RSA public key of the first certificate in the chain (RSA_PSK only).
        /// </summary>
        /// <exception cref="TlsFatalAlert">
        /// unexpected_message when the key exchange is not RSA_PSK; bad_certificate when the chain is null or
        /// empty; unsupported_certificate when the public key cannot be decoded; certificate_unknown when the
        /// public key is not an RSA key; internal_error when the decoded key is flagged private.
        /// </exception>
        public override void ProcessServerCertificate(Certificate serverCertificate)
        {
            if (mKeyExchange != KeyExchangeAlgorithm.RSA_PSK)
                throw new TlsFatalAlert(AlertDescription.unexpected_message);
            if (serverCertificate == null || serverCertificate.IsEmpty)
                throw new TlsFatalAlert(AlertDescription.bad_certificate);

            X509CertificateStructure x509Cert = serverCertificate.GetCertificateAt(0);

            SubjectPublicKeyInfo keyInfo = x509Cert.SubjectPublicKeyInfo;
            try
            {
                this.mServerPublicKey = PublicKeyFactory.CreateKey(keyInfo);
            }
            catch (Exception e)
            {
                throw new TlsFatalAlert(AlertDescription.unsupported_certificate, e);
            }

            // Sanity check the PublicKeyFactory
            if (this.mServerPublicKey.IsPrivate)
                throw new TlsFatalAlert(AlertDescription.internal_error);

            // A certificate with a non-RSA key must end the handshake with an alert,
            // not escape as an InvalidCastException.
            RsaKeyParameters rsaPublicKey = this.mServerPublicKey as RsaKeyParameters;
            if (rsaPublicKey == null)
                throw new TlsFatalAlert(AlertDescription.certificate_unknown);

            this.mRsaServerPublicKey = ValidateRsaPublicKey(rsaPublicKey);

            TlsUtilities.ValidateKeyUsage(x509Cert, KeyUsage.KeyEncipherment);

            base.ProcessServerCertificate(serverCertificate);
        }

        /// <summary>True for DHE_PSK and ECDHE_PSK; false for every other key exchange.</summary>
        public override bool RequiresServerKeyExchange
        {
            get
            {
                switch (mKeyExchange)
                {
                case KeyExchangeAlgorithm.DHE_PSK:
                case KeyExchangeAlgorithm.ECDHE_PSK:
                    return true;
                default:
                    return false;
                }
            }
        }

        /// <summary>
        /// Reads the identity hint and, for DHE_PSK/ECDHE_PSK, the peer's parameters and public key
        /// from a ServerKeyExchange body.
        /// </summary>
        public override void ProcessServerKeyExchange(Stream input)
        {
            this.mPskIdentityHint = TlsUtilities.ReadOpaque16(input);

            if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
            {
                // The received group replaces whatever was supplied at construction.
                this.mDHParameters = TlsDHUtilities.ReceiveDHParameters(mDHVerifier, input);
                this.mDHAgreePublicKey = new DHPublicKeyParameters(TlsDHUtilities.ReadDHParameter(input), mDHParameters);
            }
            else if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
            {
                ECDomainParameters ecParams = TlsEccUtilities.ReadECParameters(mNamedCurves, mClientECPointFormats, input);

                byte[] point = TlsUtilities.ReadOpaque8(input);

                this.mECAgreePublicKey = TlsEccUtilities.ValidateECPublicKey(TlsEccUtilities.DeserializeECPublicKey(
                    mClientECPointFormats, ecParams, point));
            }
        }

        /// <summary>Always rejects a CertificateRequest.</summary>
        /// <exception cref="TlsFatalAlert">unexpected_message, unconditionally.</exception>
        public override void ValidateCertificateRequest(CertificateRequest certificateRequest)
        {
            throw new TlsFatalAlert(AlertDescription.unexpected_message);
        }

        /// <summary>Always rejects client credentials.</summary>
        /// <exception cref="TlsFatalAlert">internal_error, unconditionally.</exception>
        public override void ProcessClientCredentials(TlsCredentials clientCredentials)
        {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        /// <summary>
        /// Writes the ClientKeyExchange body: the PSK identity followed by the algorithm-specific part
        /// (ephemeral DH/EC public value, or the RSA-encrypted premaster secret).
        /// </summary>
        /// <exception cref="TlsFatalAlert">
        /// internal_error when the identity or PSK is unavailable, or when the state needed for the selected
        /// algorithm (DH parameters, peer EC public key, server RSA key) has not been established.
        /// </exception>
        public override void GenerateClientKeyExchange(Stream output)
        {
            if (mPskIdentityHint == null)
            {
                mPskIdentity.SkipIdentityHint();
            }
            else
            {
                mPskIdentity.NotifyIdentityHint(mPskIdentityHint);
            }

            byte[] psk_identity = mPskIdentity.GetPskIdentity();
            if (psk_identity == null)
                throw new TlsFatalAlert(AlertDescription.internal_error);

            this.mPsk = mPskIdentity.GetPsk();
            if (mPsk == null)
                throw new TlsFatalAlert(AlertDescription.internal_error);

            TlsUtilities.WriteOpaque16(psk_identity, output);

            mContext.SecurityParameters.pskIdentity = psk_identity;

            if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
            {
                if (mDHParameters == null)
                    throw new TlsFatalAlert(AlertDescription.internal_error);

                this.mDHAgreePrivateKey = TlsDHUtilities.GenerateEphemeralClientKeyExchange(mContext.SecureRandom,
                    mDHParameters, output);
            }
            else if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
            {
                if (mECAgreePublicKey == null)
                    throw new TlsFatalAlert(AlertDescription.internal_error);

                this.mECAgreePrivateKey = TlsEccUtilities.GenerateEphemeralClientKeyExchange(mContext.SecureRandom,
                    mServerECPointFormats, mECAgreePublicKey.Parameters, output);
            }
            else if (this.mKeyExchange == KeyExchangeAlgorithm.RSA_PSK)
            {
                if (mRsaServerPublicKey == null)
                    throw new TlsFatalAlert(AlertDescription.internal_error);

                this.mPremasterSecret = TlsRsaUtilities.GenerateEncryptedPreMasterSecret(mContext,
                    this.mRsaServerPublicKey, output);
            }
        }

        /// <summary>
        /// Reads a ClientKeyExchange body: resolves the PSK for the received identity and then reads the
        /// algorithm-specific part (peer DH/EC public value, or the RSA-encrypted premaster secret).
        /// </summary>
        /// <exception cref="TlsFatalAlert">
        /// unknown_psk_identity when the identity manager has no PSK for the identity; internal_error when the
        /// local state needed for the selected algorithm has not been established.
        /// </exception>
        public override void ProcessClientKeyExchange(Stream input)
        {
            byte[] psk_identity = TlsUtilities.ReadOpaque16(input);

            this.mPsk = mPskIdentityManager.GetPsk(psk_identity);
            if (mPsk == null)
                throw new TlsFatalAlert(AlertDescription.unknown_psk_identity);

            mContext.SecurityParameters.pskIdentity = psk_identity;

            if (this.mKeyExchange == KeyExchangeAlgorithm.DHE_PSK)
            {
                if (mDHParameters == null)
                    throw new TlsFatalAlert(AlertDescription.internal_error);

                this.mDHAgreePublicKey = new DHPublicKeyParameters(TlsDHUtilities.ReadDHParameter(input), mDHParameters);
            }
            else if (this.mKeyExchange == KeyExchangeAlgorithm.ECDHE_PSK)
            {
                if (mECAgreePrivateKey == null)
                    throw new TlsFatalAlert(AlertDescription.internal_error);

                byte[] point = TlsUtilities.ReadOpaque8(input);

                // The peer's point is interpreted on the curve of our own ephemeral key.
                ECDomainParameters curve_params = this.mECAgreePrivateKey.Parameters;

                this.mECAgreePublicKey = TlsEccUtilities.ValidateECPublicKey(TlsEccUtilities.DeserializeECPublicKey(
                    mServerECPointFormats, curve_params, point));
            }
            else if (this.mKeyExchange == KeyExchangeAlgorithm.RSA_PSK)
            {
                if (mServerCredentials == null)
                    throw new TlsFatalAlert(AlertDescription.internal_error);

                byte[] encryptedPreMasterSecret;
                if (TlsUtilities.IsSsl(mContext))
                {
                    // TODO Do any SSLv3 clients actually include the length?
                    encryptedPreMasterSecret = Streams.ReadAll(input);
                }
                else
                {
                    encryptedPreMasterSecret = TlsUtilities.ReadOpaque16(input);
                }

                this.mPremasterSecret = mServerCredentials.DecryptPreMasterSecret(encryptedPreMasterSecret);
            }
        }

        /// <summary>
        /// Builds the premaster secret as two opaque16 fields: the "other secret" followed by the PSK.
        /// The stored PSK is zeroed and released afterwards.
        /// </summary>
        /// <returns>The encoded premaster secret.</returns>
        /// <exception cref="TlsFatalAlert">
        /// internal_error when no PSK has been established or no other secret is available.
        /// </exception>
        public override byte[] GeneratePremasterSecret()
        {
            if (mPsk == null)
                throw new TlsFatalAlert(AlertDescription.internal_error);

            byte[] other_secret = GenerateOtherSecret(mPsk.Length);
            if (other_secret == null)
                throw new TlsFatalAlert(AlertDescription.internal_error);

            // 4 = the two 2-byte length prefixes written by WriteOpaque16.
            MemoryStream buf = new MemoryStream(4 + other_secret.Length + mPsk.Length);
            TlsUtilities.WriteOpaque16(other_secret, buf);
            TlsUtilities.WriteOpaque16(mPsk, buf);

            // Wipe our copy of the key now that it has been folded into the premaster secret.
            Arrays.Fill(mPsk, (byte)0);
            this.mPsk = null;

            return buf.ToArray();
        }

        /// <summary>
        /// Produces the algorithm-specific "other secret": the DH or ECDH agreement result, the RSA-exchanged
        /// premaster secret, or (plain PSK) a zero-filled array of <paramref name="pskLength"/> bytes.
        /// </summary>
        /// <param name="pskLength">Length of the PSK in bytes; only used for the plain PSK case.</param>
        /// <exception cref="TlsFatalAlert">
        /// internal_error when either side of the DH/ECDH agreement is missing.
        /// </exception>
        protected virtual byte[] GenerateOtherSecret(int pskLength)
        {
            switch (this.mKeyExchange)
            {
            case KeyExchangeAlgorithm.DHE_PSK:
                // Both our private key and the peer's public key are required for the agreement.
                if (mDHAgreePrivateKey == null || mDHAgreePublicKey == null)
                    throw new TlsFatalAlert(AlertDescription.internal_error);

                return TlsDHUtilities.CalculateDHBasicAgreement(mDHAgreePublicKey, mDHAgreePrivateKey);

            case KeyExchangeAlgorithm.ECDHE_PSK:
                if (mECAgreePrivateKey == null || mECAgreePublicKey == null)
                    throw new TlsFatalAlert(AlertDescription.internal_error);

                return TlsEccUtilities.CalculateECDHBasicAgreement(mECAgreePublicKey, mECAgreePrivateKey);

            case KeyExchangeAlgorithm.RSA_PSK:
                return this.mPremasterSecret;

            default:
                return new byte[pskLength];
            }
        }

        /// <summary>Checks the server's RSA public key before it is used for encryption.</summary>
        /// <returns>The same <paramref name="key"/> instance when it is acceptable.</returns>
        /// <exception cref="TlsFatalAlert">
        /// illegal_parameter when the public exponent fails the probable-prime test.
        /// </exception>
        protected virtual RsaKeyParameters ValidateRsaPublicKey(RsaKeyParameters key)
        {
            // TODO What is the minimum bit length required?
            // key.Modulus.BitLength;

            if (!key.Exponent.IsProbablePrime(2))
                throw new TlsFatalAlert(AlertDescription.illegal_parameter);

            return key;
        }
    }
}
#pragma warning restore
#endif